%% Copyright (C) 2024 by sysmocom - s.f.m.c. GmbH <info@sysmocom.de>
%% Author: Vadim Yanitskiy <vyanitskiy@sysmocom.de>
%%
%% All Rights Reserved
%%
%% SPDX-License-Identifier: AGPL-3.0-or-later
%%
%% This program is free software; you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as
%% published by the Free Software Foundation; either version 3 of the
%% License, or (at your option) any later version.
%%
%% This program is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%% GNU General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with this program.  If not, see <https://www.gnu.org/licenses/>.
%%
%% Additional Permission under GNU AGPL version 3 section 7:
%%
%% If you modify this Program, or any covered work, by linking or
%% combining it with runtime libraries of Erlang/OTP as released by
%% Ericsson on https://www.erlang.org (or a modified version of these
%% libraries), containing parts covered by the terms of the Erlang Public
%% License (https://www.erlang.org/EPLICENSE), the licensors of this
%% Program grant you additional permission to convey the resulting work
%% without the need to license the runtime libraries of Erlang/OTP under
%% the GNU Affero General Public License. Corresponding Source for a
%% non-source form of such a combination shall include the source code
%% for the parts of the runtime libraries of Erlang/OTP used as well as
%% that of the covered work.

-module(sctp_server).
-behaviour(gen_server).

-export([init/1,
         handle_info/2,
         handle_call/3,
         handle_cast/2,
         terminate/2]).
-export([start_link/3,
         send_data/2,
         shutdown/0]).

-include_lib("kernel/include/logger.hrl").
-include_lib("kernel/include/inet.hrl").
-include_lib("kernel/include/inet_sctp.hrl").

-include("s1ap.hrl").

%% State of the SCTP server process:
%%  sock          - the listening gen_sctp socket opened in init/1
%%  clients       - dict mapping an eNB association id (Aid) to the
%%                  #client_state{} record of that association
%%  mme_addr_port - {MmeAddr, MmePort} tuple passed to sctp_proxy:start_link/3
%%                  for every newly established eNB association
-record(server_state, {sock, clients, mme_addr_port}).
%% Per-association (eNB) state:
%%  addr_port - {FromAddr, FromPort} of the eNB peer, used for logging
%%  pid       - pid of the sctp_proxy process started (via start_link) for
%%              this association; data from the eNB is forwarded to it
-record(client_state, {addr_port, pid}).

%% ------------------------------------------------------------------
%% public API
%% ------------------------------------------------------------------

%% Start the SCTP server, linked to the caller and registered locally
%% under the module name.  BindAddr is either a string (parsed by init/1)
%% or an address term usable as gen_sctp's {ip, _} option; MmeAddrPort is
%% the {MmeAddr, MmePort} tuple every eNB association gets proxied to.
start_link(BindAddr, BindPort, MmeAddrPort) ->
    ServerName = {local, ?MODULE},
    InitArgs = [BindAddr, BindPort, MmeAddrPort],
    gen_server:start_link(ServerName, ?MODULE, InitArgs, []).


%% Asynchronously send Data to the eNB behind SCTP association Aid.
%% This is a cast, so it returns ok immediately; delivery happens in
%% handle_cast/2.
send_data(Aid, Data) ->
    Request = {send_data, Aid, Data},
    gen_server:cast(?MODULE, Request).


%% Synchronously stop the registered server.  terminate/2 takes care of
%% closing all eNB connections and the listening socket.
shutdown() ->
    Server = ?MODULE,
    gen_server:stop(Server).


%% ------------------------------------------------------------------
%% gen_server API
%% ------------------------------------------------------------------

%% Initialize the server: open an SCTP seqpacket socket bound to
%% BindAddr:BindPort in active mode and start accepting associations.
%% BindAddr may be given as a string, in which case it is parsed first.
%% Exits are trapped so that termination of linked processes (the
%% sctp_proxy children and the socket port) reaches handle_info/2.
%% Any failure to parse/open/listen crashes init (assertive matching),
%% which makes start_link/3 return an error.
init([BindAddrStr, BindPort, MmeAddrPort]) when is_list(BindAddrStr) ->
    {ok, BindAddr} = inet:parse_address(BindAddrStr),
    init([BindAddr, BindPort, MmeAddrPort]);

init([BindAddr, BindPort, MmeAddrPort]) ->
    process_flag(trap_exit, true),
    {ok, Sock} = gen_sctp:open([{ip, BindAddr},
                                {port, BindPort},
                                {type, seqpacket},
                                {reuseaddr, true},
                                {active, true}]),
    ok = gen_sctp:listen(Sock, true),
    %% only report success once the socket is actually listening
    ?LOG_INFO("SCTP server listening on ~w:~w", [BindAddr, BindPort]),
    {ok, #server_state{sock = Sock,
                       clients = dict:new(),
                       mme_addr_port = MmeAddrPort}}.


%% No synchronous requests are supported: log the request and reply
%% with {error, not_implemented}, leaving the state untouched.
handle_call(Request, From, State) ->
    ?LOG_ERROR("unknown ~p() from ~p: ~p", [?FUNCTION_NAME, From, Request]),
    Reply = {error, not_implemented},
    {reply, Reply, State}.


%% Send Data to the eNB association Aid on the S1AP stream, tagged with
%% the S1AP payload protocol identifier.  gen_sctp:send/3 returns
%% ok | {error, Reason}; failures are logged rather than silently
%% dropped, and the server keeps running either way.
handle_cast({send_data, Aid, Data}, #server_state{sock = Sock} = State) ->
    SndRcvInfo = #sctp_sndrcvinfo{stream = ?S1AP_SCTP_STREAM,
                                  ppid = ?S1AP_SCTP_PPID,
                                  assoc_id = Aid},
    case gen_sctp:send(Sock, SndRcvInfo, Data) of
        ok ->
            ok;
        {error, Reason} ->
            ?LOG_ERROR("eNB connection (id=~p): failed to send data: ~p",
                       [Aid, Reason])
    end,
    {noreply, State};

%% Catch-all for unknown casts
handle_cast(Info, State) ->
    ?LOG_ERROR("unknown ~p(): ~p", [?FUNCTION_NAME, Info]),
    {noreply, State}.


%% Handle SCTP events coming from gen_sctp module (the socket is opened
%% in active mode, so events arrive as {sctp, ...} messages)
handle_info({sctp, _Socket, FromAddr, FromPort, {AncData, Data}}, State) ->
    NewState = sctp_recv(State, {FromAddr, FromPort, AncData, Data}),
    {noreply, NewState};

%% The listening socket itself terminated.  A port is linked to the
%% process that opened it, and exits are trapped (see init/1), so this
%% arrives as a message.  No eNB can be served without the socket, so stop
%% instead of carrying on with a dead one; terminate/2 then shuts down the
%% remaining proxy processes.
handle_info({'EXIT', Sock, Reason}, #server_state{sock = Sock} = State) ->
    ?LOG_ERROR("SCTP server socket ~p terminated with reason ~p", [Sock, Reason]),
    {stop, {sctp_socket_exit, Reason}, State};

%% Handle termination events of the child (sctp_proxy) processes
handle_info({'EXIT', Pid, Reason},
            #server_state{sock = Sock, clients = Clients} = State) ->
    ?LOG_DEBUG("Child process ~p terminated with reason ~p", [Pid, Reason]),
    case client_find(State, Pid) of
        {ok, {Aid, _Client}} ->
            %% gracefully close the eNB connection
            gen_sctp:eof(Sock, #sctp_assoc_change{assoc_id = Aid}),
            {noreply, State#server_state{clients = dict:erase(Aid, Clients)}};
        error ->
            %% not one of our clients (e.g. already removed by client_del/2)
            {noreply, State}
    end;

%% Catch-all for unknown messages
handle_info(Info, State) ->
    ?LOG_ERROR("unknown ~p(): ~p", [?FUNCTION_NAME, Info]),
    {noreply, State}.


%% Called when the server stops: first shut down every proxied eNB
%% connection, then close the listening socket.  Always returns ok.
terminate(Reason, State) ->
    ?LOG_NOTICE("Terminating, reason ~p", [Reason]),
    close_conns(State),
    #server_state{sock = ListenSock} = State,
    gen_sctp:close(ListenSock),
    ok.

%% ------------------------------------------------------------------
%% private API
%% ------------------------------------------------------------------

%% Dispatch one SCTP event received on the listening socket.  The event
%% is the {FromAddr, FromPort, AncData, Data} tuple built by handle_info/2;
%% the (possibly updated) server state is returned.

%% #sctp_assoc_change{state = comm_up}: a new eNB association has been
%% established, so start an sctp_proxy process for it.
sctp_recv(State, {FromAddr, FromPort, [],
                  #sctp_assoc_change{state = comm_up, assoc_id = Aid}}) ->
    ?LOG_NOTICE("eNB connection (id=~p, ~p:~p) established", [Aid, FromAddr, FromPort]),
    #server_state{clients = Clients, mme_addr_port = MmeAddrPort} = State,
    NewClients = client_add(Clients, Aid, FromAddr, FromPort, MmeAddrPort),
    State#server_state{clients = NewClients};

%% #sctp_assoc_change{state = shutdown_comp}: the association was closed,
%% so forget the client (which also stops its proxy).
sctp_recv(State, {FromAddr, FromPort, [],
                  #sctp_assoc_change{state = shutdown_comp, assoc_id = Aid}}) ->
    ?LOG_NOTICE("eNB connection (id=~p, ~p:~p) closed", [Aid, FromAddr, FromPort]),
    NewClients = client_del(State#server_state.clients, Aid),
    State#server_state{clients = NewClients};

%% #sctp_assoc_change{state = comm_lost}: the association was lost,
%% so forget the client (which also stops its proxy).
sctp_recv(State, {FromAddr, FromPort, [],
                  #sctp_assoc_change{state = comm_lost, assoc_id = Aid}}) ->
    ?LOG_NOTICE("eNB connection (id=~p, ~p:~p) lost", [Aid, FromAddr, FromPort]),
    NewClients = client_del(State#server_state.clients, Aid),
    State#server_state{clients = NewClients};

%% Any other association state is only reported; clients are unchanged.
sctp_recv(State, {FromAddr, FromPort, [],
                  #sctp_assoc_change{state = ConnState, assoc_id = Aid}}) ->
    ?LOG_NOTICE("eNB connection (id=~p, ~p:~p) state ~p",
                [Aid, FromAddr, FromPort, ConnState]),
    State;

%% #sctp_sndrcvinfo{} ancillary data: payload from an eNB, forward it to
%% the sctp_proxy process serving that association.
sctp_recv(State, {FromAddr, FromPort,
                  [#sctp_sndrcvinfo{assoc_id = Aid}], Data}) ->
    ?LOG_DEBUG("eNB connection (id=~p, ~p:~p) -> MME: ~p", [Aid, FromAddr, FromPort, Data]),
    Clients = State#server_state.clients,
    case dict:find(Aid, Clients) of
        error ->
            ?LOG_ERROR("eNB connection (id=~p, ~p:~p) is not known to us?!?",
                       [Aid, FromAddr, FromPort]);
        {ok, #client_state{pid = ProxyPid}} ->
            sctp_proxy:send_data(ProxyPid, Data)
    end,
    State;

%% Anything else (other SCTP notifications, unexpected ancillary data)
%% is only logged.
sctp_recv(State, {Addr, Port, Anc, Payload}) ->
    ?LOG_DEBUG("Unhandled SCTP event (~p:~p): ~p, ~p",
               [Addr, Port, Anc, Payload]),
    State.


%% Register a new eNB association: start (via sctp_proxy:start_link/3) a
%% proxy process towards {MmeAddr, MmePort} and store it under Aid.
%% Returns the updated Clients dict; crashes if the proxy fails to start.
client_add(Clients, Aid, FromAddr, FromPort, {MmeAddr, MmePort}) ->
    Entry = #client_state{addr_port = {FromAddr, FromPort}},
    {ok, ProxyPid} = sctp_proxy:start_link(Aid, MmeAddr, MmePort),
    dict:store(Aid, Entry#client_state{pid = ProxyPid}, Clients).


%% Remove the eNB association Aid from Clients, stopping its sctp_proxy
%% process first.  Unknown associations are ignored.  Returns the updated
%% Clients dict.
client_del(Clients, Aid) ->
    case dict:find(Aid, Clients) of
        {ok, #client_state{pid = Pid}} ->
            %% the proxy process might be already dead, so we guard
            %% against exit exceptions like noproc or {nodedown,Node};
            %% anything else is a genuine bug and is allowed to crash.
            try sctp_proxy:shutdown(Pid)
            catch
                exit:Reason ->
                    ?LOG_DEBUG("Proxy process ~p is already gone: ~p", [Pid, Reason])
            end,
            dict:erase(Aid, Clients);
        error ->
            Clients
    end.


%% Look up the client whose proxy process is Pid.  Accepts either the
%% server state or a list of {Aid, #client_state{}} pairs.  Returns
%% {ok, {Aid, Client}} for the first match, or error if there is none.
client_find(#server_state{clients = Clients}, Pid) ->
    client_find(dict:to_list(Clients), Pid);

client_find(ClientList, Pid) when is_list(ClientList) ->
    IsOwner = fun({_Aid, Client}) ->
                      case Client of
                          #client_state{pid = Pid} -> true;
                          _ -> false
                      end
              end,
    case lists:search(IsOwner, ClientList) of
        {value, Found} -> {ok, Found};
        false -> error
    end.


%% Gracefully terminate all client connections: for every known eNB
%% association, stop its sctp_proxy process and then send an SCTP EOF on
%% the eNB association.  Returns ok.
close_conns(#server_state{sock = Sock, clients = Clients}) ->
    lists:foreach(fun({Aid, Client}) -> close_conn(Sock, Aid, Client) end,
                  dict:to_list(Clients)).

%% Close a single eNB association Aid whose state is Client.
close_conn(Sock, Aid, #client_state{addr_port = {FromAddr, FromPort}, pid = Pid}) ->
    ?LOG_NOTICE("Terminating eNB connection (id=~p, ~p:~p)", [Aid, FromAddr, FromPort]),
    %% request to terminate an MME connection; the proxy process might be
    %% already dead, so we guard against exits like noproc or {nodedown,Node}
    try sctp_proxy:shutdown(Pid)
    catch
        exit:Reason ->
            ?LOG_DEBUG("Proxy process ~p is already gone: ~p", [Pid, Reason])
    end,
    %% gracefully close an eNB connection (best effort, result ignored)
    gen_sctp:eof(Sock, #sctp_assoc_change{assoc_id = Aid}).

%% vim:set ts=4 sw=4 et:
