@dataclass
class State(Generic[A, S]):
    """A state-threading computation, usable as a monad.

    Wraps ``action``, a function from an input state to a
    ``(result, next_state)`` pair. The methods below build new ``State``
    values by composing actions; no action is invoked until ``run``,
    ``run_state`` (or ``action`` itself) is called with an initial state.
    """

    # The wrapped step function: state -> (result, next state).
    action: Callable[[S], Tuple[A, S]]

    # ----- primitive computations -------------------------------------

    @staticmethod
    def pure(b: B) -> State[B, S2]:
        """Lift ``b`` into a computation that yields it and leaves the
        state exactly as it was given."""
        return State(action=lambda state: (b, state))

    @staticmethod
    def read() -> State[S, S]:
        """Computation whose result is the current state; the state itself
        is passed through unchanged."""
        return State(action=lambda state: (state, state))

    @staticmethod
    def write(s2: S) -> State[None, S]:
        """Computation that ignores the incoming state, replaces it with
        ``s2``, and yields ``None`` as its result."""
        return State(action=lambda _ignored: (None, s2))

    # ----- combinators ------------------------------------------------

    def fmap(self, f: Callable[[A], B]) -> State[B, S]:
        """Return a computation that runs this one and applies ``f`` to
        its result, keeping the resulting state untouched."""

        def mapped(state: S) -> Tuple[B, S]:
            result, next_state = self.action(state)
            return f(result), next_state

        return State(action=mapped)

    def flatmap(self: State[A, S],
                f: Callable[[A], State[B, S]]) -> State[B, S]:
        """Sequence this computation with ``f``.

        The returned computation runs ``self``, hands its result to ``f``
        to obtain the follow-up computation, and runs that follow-up on
        the state ``self`` produced.
        """

        def chained(state: S) -> Tuple[B, S]:
            first_result, mid_state = self.action(state)
            follow_up = f(first_result)
            return follow_up.action(mid_state)

        return State(action=chained)

    # ----- execution --------------------------------------------------

    def run(self, s: S) -> Tuple[A, S]:
        """Execute this computation from initial state ``s`` and return
        the ``(result, final_state)`` pair."""
        step = self.action
        return step(s)

    @staticmethod
    def run_state(sm: State[A, S], s: S) -> Tuple[A, S]:
        """Function-style runner: feed ``s`` to ``sm``'s action and return
        its ``(result, final_state)`` pair (same as ``sm.run(s)`` for a
        plain ``State``)."""
        step = sm.action
        return step(s)