Source code for tensortrade.agents.trading_agent

import gin
import pandas as pd
import numpy as np

from abc import ABCMeta, abstractmethod
from typing import Union, Callable, List

from tensortrade.environments.trading_environment import TradingEnvironment
from tensortrade.features.feature_pipeline import FeaturePipeline


@gin.configurable
class TradingAgent(object, metaclass=ABCMeta):
    """An abstract base class for trading agents capable of self tuning, training, and evaluating.

    Concrete agents must implement `tune`, `train`, `evaluate`, and `get_action`.
    """

    def __init__(self, env: TradingEnvironment, feature_pipeline: FeaturePipeline):
        """
        Args:
            env: A `TradingEnvironment` instance for the agent to trade within.
            feature_pipeline: A `FeaturePipeline` instance of feature transformations.
        """
        self._env = env
        self._feature_pipeline = feature_pipeline

    @property
    def env(self):
        """A `TradingEnvironment` instance for the agent to trade within."""
        return self._env

    @env.setter
    def env(self, env: TradingEnvironment):
        self._env = env

    @property
    def feature_pipeline(self):
        """A `FeaturePipeline` instance of feature transformations."""
        return self._feature_pipeline

    @feature_pipeline.setter
    def feature_pipeline(self, feature_pipeline: FeaturePipeline):
        # Write to the private backing field. Assigning to `self.feature_pipeline`
        # here would re-enter this setter and recurse until RecursionError.
        self._feature_pipeline = feature_pipeline

    @gin.configurable
    @abstractmethod
    def tune(self,
             steps_per_train: int,
             steps_per_test: int,
             step_cb: Callable[[pd.DataFrame], bool]) -> pd.DataFrame:
        """Tune the agent's hyper-parameters and feature set for the environment.

        Args:
            steps_per_train: The number of steps per training of each hyper-parameter set.
            steps_per_test: The number of steps per evaluation of each hyper-parameter set.
            step_cb: A callback function for monitoring progress of the tuning process.
                step_cb(pd.DataFrame) -> bool: A history of the agent's trading performance
                is passed on each iteration. If the callback returns `True`, the tuning
                process will stop early.

        Returns:
            A history of the agent's trading performance during tuning

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.
        """
        raise NotImplementedError

    @gin.configurable
    @abstractmethod
    def train(self, steps: int, callback: Callable[[pd.DataFrame], bool]) -> pd.DataFrame:
        """Train the agent's underlying model on the environment.

        Args:
            steps: The number of steps to train the model within the environment.
            callback: A callback function for monitoring progress of the training process.
                callback(pd.DataFrame) -> bool: A history of the agent's trading performance
                is passed on each iteration. If the callback returns `True`, the training
                process will stop early.

        Returns:
            A history of the agent's trading performance during training

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.
        """
        raise NotImplementedError

    @gin.configurable
    @abstractmethod
    def evaluate(self, steps: int, callback: Callable[[pd.DataFrame], bool]) -> pd.DataFrame:
        """Evaluate the agent's performance within the environment.

        Args:
            steps: The number of steps to evaluate the agent within the environment.
            callback: A callback function for monitoring progress of the evaluation process.
                callback(pd.DataFrame) -> bool: A history of the agent's trading performance
                is passed on each iteration. If the callback returns `True`, the evaluation
                process will stop early.

        Returns:
            A history of the agent's trading performance during evaluation

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.
        """
        raise NotImplementedError

    @abstractmethod
    def get_action(self, observation: pd.DataFrame) -> Union[float, List[float]]:
        """Determine an action based on a specific observation.

        Args:
            observation: A `pandas.DataFrame` corresponding to an observation within the environment.

        Returns:
            An action whose type depends on the action space of the environment.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.
        """
        raise NotImplementedError