Source code for tensortrade.environments.actions.continuous.simple_continuous_strategy

from typing import Union

from tensortrade.environments.actions import ActionStrategy, TradeType
from tensortrade.exchanges import AssetExchange


[docs]class SimpleContinuousStrategy(ActionStrategy): """Simple continuous strategy, that executes trades on a continuous basis."""
[docs] def __init__(self, base_symbol: str = 'USD', asset_symbol: str = 'BTC'): """ # Arguments base_symbol: optional the symbol that defines the notional value asset_symbol: optional the asset symbol """ super().__init__(action_space_shape=(1, 1), continuous_action_space=True) self.base_symbol = base_symbol self.asset_symbol = asset_symbol
[docs] def get_trade(self, action: Union[int, tuple], exchange: AssetExchange): """Suggest a trade to the trading environment # Arguments action: optional number of timesteps used to calculate the trade amount exchange: the AssetExchange # Returns the asset symbol, the type of trade, amount and price """ action_type, trade_amount = action trade_type = TradeType(int(action_type * 3)) current_price = exchange.current_price( symbol=self.asset_symbol, output_symbol=self.base_symbol) commission_percent = exchange.commission_percent base_precision = exchange.base_precision asset_precision = exchange.asset_precision price = current_price amount_to_trade = 0 if trade_type == TradeType.BUY: balance = exchange.balance(self.base_symbol) price_adjustment = 1 + (commission_percent / 100) price = round(current_price * price_adjustment, base_precision) amount_to_trade = round(balance * trade_amount / price, asset_precision) elif trade_type == TradeType.SELL: portfolio = exchange.portfolio() price_adjustment = 1 + (commission_percent / 100) price = round(current_price * price_adjustment, base_precision) amount_to_trade = round(portfolio.get(self.asset_symbol, 0) * trade_amount, asset_precision) return self.asset_symbol, trade_type, amount_to_trade, price