%% Copyright (C) 2024 by sysmocom - s.f.m.c. GmbH <info@sysmocom.de>
%% Author: Vadim Yanitskiy <vyanitskiy@sysmocom.de>
%%
%% All Rights Reserved
%%
%% SPDX-License-Identifier: AGPL-3.0-or-later
%%
%% This program is free software; you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as
%% published by the Free Software Foundation; either version 3 of the
%% License, or (at your option) any later version.
%%
%% This program is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%% GNU General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with this program.  If not, see <https://www.gnu.org/licenses/>.
%%
%% Additional Permission under GNU AGPL version 3 section 7:
%%
%% If you modify this Program, or any covered work, by linking or
%% combining it with runtime libraries of Erlang/OTP as released by
%% Ericsson on https://www.erlang.org (or a modified version of these
%% libraries), containing parts covered by the terms of the Erlang Public
%% License (https://www.erlang.org/EPLICENSE), the licensors of this
%% Program grant you additional permission to convey the resulting work
%% without the need to license the runtime libraries of Erlang/OTP under
%% the GNU Affero General Public License. Corresponding Source for a
%% non-source form of such a combination shall include the source code
%% for the parts of the runtime libraries of Erlang/OTP used as well as
%% that of the covered work.

-module(sctp_proxy).
-behaviour(gen_statem).

-export([init/1,
         callback_mode/0,
         connecting/3,
         connected/3,
         code_change/4,
         terminate/3]).
-export([start_link/3,
         send_data/2,
         shutdown/1]).

-include_lib("kernel/include/logger.hrl").
-include_lib("kernel/include/inet.hrl").
-include_lib("kernel/include/inet_sctp.hrl").

-include("s1gw_metrics.hrl").

%% ------------------------------------------------------------------
%% public API
%% ------------------------------------------------------------------

%% Start a proxy process linked to the caller.
%%   Aid     - SCTP association id of the eNB side
%%   MmeAddr - MME address to connect to
%%   MmePort - MME SCTP port to connect to
-spec start_link(gen_sctp:assoc_id(),
                 sctp_client:loc_rem_addr(),
                 inet:port_number()) -> gen_statem:start_ret().
start_link(Aid, MmeAddr, MmePort) ->
    InitArgs = [Aid, MmeAddr, MmePort],
    gen_statem:start_link(?MODULE, InitArgs, []).


%% Asynchronously hand an eNB -> MME PDU to the proxy process.
%% Depending on the proxy's state it is queued or forwarded.
-spec send_data(pid(), binary()) -> ok.
send_data(Pid, Data) ->
    Request = {send_data, Data},
    gen_statem:cast(Pid, Request).


%% Synchronously stop the proxy process with reason 'normal',
%% waiting as long as it takes for terminate/3 to finish.
-spec shutdown(pid()) -> ok.
shutdown(Pid) ->
    gen_statem:stop(Pid, normal, infinity).


%% ------------------------------------------------------------------
%% gen_statem API
%% ------------------------------------------------------------------

%% Build the initial state data and start in the 'connecting' state.
%% The data map holds the eNB association id, the MME address/port,
%% an empty uplink queue and fresh s1ap_proxy state.
init([Aid, MmeAddr, MmePort]) ->
    %% Trap exits so that 'EXIT' signals arrive as messages
    %% (handled in the 'connected' state).
    process_flag(trap_exit, true),
    Priv = s1ap_proxy:init(),
    Data = #{enb_aid => Aid,
             mme_addr => MmeAddr,
             mme_port => MmePort,
             tx_queue => [],
             priv => Priv},
    {ok, connecting, Data}.


%% One callback function per state, with enter calls enabled so that
%% connecting/3 and connected/3 get an 'enter' event on every transition.
callback_mode() ->
    Modes = [state_functions, state_enter],
    Modes.


%% CONNECTING state: an SCTP association towards the MME is being set up.
%% Uplink (eNB -> MME) data is queued until the association comes up.
connecting(enter, OldState,
           #{mme_addr := MmeAddr, mme_port := MmePort} = S) ->
    ?LOG_INFO("State change: ~p -> ~p", [OldState, ?FUNCTION_NAME]),
    %% Start connecting to the MME (crash if that fails immediately) and
    %% give the association 2000 ms to come up.
    {ok, Sock} = sctp_client:connect(MmeAddr, MmePort),
    ConnTimeout = {state_timeout, 2_000, conn_est_timeout},
    {keep_state, S#{sock => Sock}, [ConnTimeout]};

%% The association did not come up in time
connecting(state_timeout, conn_est_timeout, _S) ->
    {stop, {shutdown, conn_est_timeout}};

%% eNB -> MME data while not yet connected: queue it (newest first)
connecting(cast, {send_data, Data}, #{tx_queue := Queue} = S) ->
    s1gw_metrics:ctr_inc(?S1GW_CTR_S1AP_PROXY_UPLINK_PACKETS_QUEUED),
    s1gw_metrics:gauge_inc(?S1GW_GAUGE_S1AP_PROXY_UPLINK_PACKETS_QUEUED),
    {keep_state, S#{tx_queue := [Data | Queue]}};

%% #sctp_assoc_change reporting comm_up: the association is established
connecting(info, {sctp, _Socket, MmeAddr, MmePort,
                  {[], #sctp_assoc_change{state = comm_up,
                                          assoc_id = Aid}}}, S) ->
    ?LOG_NOTICE("MME connection (id=~p, ~p:~p) established",
                [Aid, MmeAddr, MmePort]),
    {next_state, connected, S#{mme_aid => Aid}};

%% #sctp_assoc_change reporting any other state: setup failed
connecting(info, {sctp, _Socket, _MmeAddr, _MmePort,
                  {[], #sctp_assoc_change{state = ConnState}}}, _S) ->
    ?LOG_NOTICE("MME connection establishment failed: ~p", [ConnState]),
    {stop, {shutdown, conn_est_fail}};

%% Any other SCTP event is only logged
connecting(info, {sctp, _Socket, MmeAddr, MmePort, {AncData, Data}}, _S) ->
    ?LOG_DEBUG("Unhandled SCTP event (~p:~p): ~p, ~p",
               [MmeAddr, MmePort, AncData, Data]),
    keep_state_and_data;

%% Anything else is unexpected in this state
connecting(Event, EventData, _S) ->
    ?LOG_ERROR("Unexpected event ~p: ~p", [Event, EventData]),
    keep_state_and_data.


%% CONNECTED state: the MME association is up.  PDUs in both directions
%% are run through s1ap_proxy:process_pdu_safe/2, which decides whether
%% each one is forwarded or answered back to its sender.
connected(enter, OldState, S) ->
    ?LOG_INFO("State change: ~p -> ~p", [OldState, ?FUNCTION_NAME]),
    %% Flush whatever the eNB sent while we were still connecting
    {keep_state, sctp_send_pending(S)};

%% eNB -> MME data: process and send right away
connected(cast, {send_data, Data}, S) ->
    {keep_state, sctp_send(Data, S)};

%% #sctp_assoc_change reporting comm_up again: odd, but harmless
connected(info, {sctp, _Socket, MmeAddr, MmePort,
                 {[], #sctp_assoc_change{state = comm_up,
                                         assoc_id = Aid}}}, _S) ->
    ?LOG_NOTICE("MME connection (id=~p, ~p:~p) is already established?!?",
                [Aid, MmeAddr, MmePort]),
    keep_state_and_data;

%% #sctp_assoc_change reporting any other state: the association is lost
connected(info, {sctp, _Socket, _MmeAddr, _MmePort,
                 {[], #sctp_assoc_change{state = ConnState}}}, _S) ->
    ?LOG_NOTICE("MME connection state: ~p", [ConnState]),
    {stop, {shutdown, conn_fail}};

%% MME -> eNB data on our association (#sctp_sndrcvinfo)
connected(info, {sctp, _Socket, MmeAddr, MmePort,
                 {[#sctp_sndrcvinfo{assoc_id = Aid}], Data}},
          #{sock := Sock,
            enb_aid := EnbAid,
            mme_aid := Aid,
            priv := Priv0} = S) ->
    ?LOG_DEBUG("MME connection (id=~p, ~p:~p) -> eNB: ~p",
               [Aid, MmeAddr, MmePort, Data]),
    {Action, Priv1} = s1ap_proxy:process_pdu_safe(Data, Priv0),
    case Action of
        {reply, ToMme} ->
            ok = sctp_client:send_data({Sock, Aid}, ToMme);
        {forward, ToEnb} ->
            sctp_server:send_data(EnbAid, ToEnb)
    end,
    {keep_state, S#{priv := Priv1}};

%% A linked process exited (exits are trapped, see init/1)
connected(info, {'EXIT', Pid, Reason}, #{priv := Priv0} = S) ->
    ?LOG_DEBUG("Child process ~p terminated with reason ~p", [Pid, Reason]),
    {keep_state, S#{priv := s1ap_proxy:handle_exit(Pid, Priv0)}};

%% Any other SCTP event is only logged
connected(info, {sctp, _Socket, MmeAddr, MmePort, {AncData, Data}}, _S) ->
    ?LOG_DEBUG("Unhandled SCTP event (~p:~p): ~p, ~p",
               [MmeAddr, MmePort, AncData, Data]),
    keep_state_and_data;

%% Anything else is unexpected in this state
connected(Event, EventData, _S) ->
    ?LOG_ERROR("Unexpected event ~p: ~p", [Event, EventData]),
    keep_state_and_data.



%% Hot code upgrade: state name and data are carried over unchanged.
code_change(_OldVsn, StateName, Data, _Extra) ->
    {ok, StateName, Data}.


%% Clean up on termination.
%% First un-account any eNB -> MME messages that are still queued (they
%% are dropped).  Then tear down as much as was set up:
%%   - association established (mme_aid set): deinit s1ap_proxy state,
%%     disconnect the association and close the socket;
%%   - socket opened but no association: just close the socket;
%%   - no socket at all: nothing to do.
terminate(Reason, State, S) ->
    ?LOG_NOTICE("Terminating in state ~p, reason ~p", [State, Reason]),
    drop_pending(S),
    case S of
        #{sock := Sock,
          mme_aid := Aid,
          priv := Priv} ->
            s1ap_proxy:deinit(Priv),
            sctp_client:disconnect({Sock, Aid}),
            gen_sctp:close(Sock);
        #{sock := Sock} ->
            gen_sctp:close(Sock);
        _ -> ok
    end.


%% Every message queued in the 'connecting' state incremented the
%% queued-packets gauge, which is otherwise only decremented when the
%% queue is flushed on entering 'connected'.  When the process stops
%% before that (setup timeout/failure, shutdown), decrement the gauge once
%% per dropped message so it does not drift upwards.
%% NOTE(review): if the flush itself crashed midway, messages already sent
%% were decremented there too; the gauge may then undershoot slightly.
drop_pending(#{tx_queue := [_ | _] = Pending}) ->
    ?LOG_NOTICE("Dropping ~p pending eNB -> MME message(s)",
                [length(Pending)]),
    lists:foreach(
      fun(_) ->
              s1gw_metrics:gauge_dec(?S1GW_GAUGE_S1AP_PROXY_UPLINK_PACKETS_QUEUED)
      end, Pending);
drop_pending(_S) ->
    ok.


%% ------------------------------------------------------------------
%% private API
%% ------------------------------------------------------------------

%% Process one eNB -> MME PDU through s1ap_proxy and act on the verdict:
%% a 'forward' goes to the MME association, a 'reply' goes back to the
%% eNB.  Returns the state data with the updated s1ap_proxy state.
sctp_send(Data,
          #{sock := Sock,
            enb_aid := EnbAid,
            mme_aid := MmeAid,
            priv := Priv0} = S) ->
    {Action, Priv1} = s1ap_proxy:process_pdu_safe(Data, Priv0),
    case Action of
        {reply, ToEnb} ->
            sctp_server:send_data(EnbAid, ToEnb);
        {forward, ToMme} ->
            ok = sctp_client:send_data({Sock, MmeAid}, ToMme)
    end,
    S#{priv := Priv1}.


%% Flush the uplink (eNB -> MME) queue in arrival order.  The gauge of
%% queued packets is decremented after each message is handed to
%% sctp_send/2, and the queue is left empty.
sctp_send_pending(#{tx_queue := Pending} = S) ->
    %% Messages were prepended while queuing, so restore arrival order
    sctp_send_pending(lists:reverse(Pending), S).

sctp_send_pending(Queue, S0) ->
    S1 = lists:foldl(
           fun(Data, Acc0) ->
                   Acc1 = sctp_send(Data, Acc0),
                   s1gw_metrics:gauge_dec(?S1GW_GAUGE_S1AP_PROXY_UPLINK_PACKETS_QUEUED),
                   Acc1
           end, S0, Queue),
    S1#{tx_queue := []}.

%% vim:set ts=4 sw=4 et:
