%% Copyright (C) 2025 by sysmocom - s.f.m.c. GmbH <info@sysmocom.de>
%% Author: Vadim Yanitskiy <vyanitskiy@sysmocom.de>
%%
%% All Rights Reserved
%%
%% SPDX-License-Identifier: MPL-2.0
%%
%% This Source Code Form is subject to the terms of the Mozilla Public
%% License, v. 2.0.  If a copy of the MPL was not distributed with this
%% file, You can obtain one at http://mozilla.org/MPL/2.0/.

-module(enftables).

-export([run_cmd/1,
         run_cmd/2]).
-export([nft_cmd_add_table/1,
         nft_cmd_add_table/2,
         nft_cmd_add_table/3,
         nft_cmd_del_table/1,
         nft_cmd_del_table/2]).
-export([nft_cmd_add_rule/3,
         nft_cmd_add_rule/4,
         nft_cmd_del_rule/3,
         nft_cmd_del_rule/4]).
-export([nft_counter/2,
         nft_counter/3,
         nft_cmd_add_counter/1,
         nft_cmd_del_counter/1]).
-export([nft_cmd_list_ruleset/0,
         nft_cmd_list_tables/0,
         nft_cmd_list_chain/2,
         nft_cmd_list_chain/3,
         nft_cmd_list_counters/1,
         nft_cmd_list_counters/2]).
-export([nft_expr_match_payload/3,
         nft_expr_match_ip_proto/2,
         nft_expr_match_ip_saddr/2,
         nft_expr_match_ip_daddr/2,
         nft_expr_match_udp_dport/2,
         nft_expr_counter/1,
         nft_expr_accept/0]).

-nifs([run_cmd_nif/1]).
-on_load(init/0).

-define(LIBNAME, ?MODULE).
-define(DEFAULT_FAMILY, "inet").

%% Command verbs; one of these atoms is the key of an nft_cmd() map.
-type nft_cmd_type() :: add
                      | replace
                      | create
                      | insert
                      | delete
                      | list
                      | reset
                      | flush
                      | rename.
%% A single command: a map from a command verb to the map describing
%% the object the verb applies to (see the nft_cmd_* templates below).
-type nft_cmd() :: #{nft_cmd_type() => map()}.
%% A rule expression, as produced by the nft_expr_* templates below.
-type nft_expr() :: map().
%% A counter reference, as produced by nft_counter/2,3.
-type nft_counter() :: map().

%% Operators: "==" | "!=" | "<" | ">" | "<=" | ">=" | "in"
-type nft_op() :: string().

-type cmds() :: [nft_cmd()].    %% list of commands
-type rsps() :: [map()].        %% list of responses

%% Input format accepted by run_cmd/2: a list of command maps that gets
%% JSON-encoded (json), or a plain command string (text).
-type fmt() :: json | text.
%% Outcome of run_cmd/1,2 as produced by parse_result/1: `ok' for an
%% empty reply, `{ok, Rsps}' for a decoded reply, or an error tuple.
-type result() :: ok
                | {ok, rsps()}
                | {error, term()}.

%% Only a subset of the types above is exported.
-export_type([nft_cmd_type/0,
              nft_cmd/0,
              result/0]).


%% ------------------------------------------------------------------
%% public API
%% ------------------------------------------------------------------

%% @doc Execute a list of command maps, using the `json' input format.
%% Shorthand for `run_cmd(Cmds, json)'.
-spec run_cmd(Cmds :: cmds()) -> result().
run_cmd(CmdList) ->
    run_cmd(CmdList, json).


%% @doc Execute nftables commands given in the specified format.
%%
%% `json': `Cmds' is a list of command maps (see the nft_cmd_* templates);
%% it is wrapped into a top-level "nftables" object and encoded with jiffy.
%%
%% `text': `Cmds' is the command text as iodata (a string, a binary or
%% an iolist).
%%
%% In both cases a trailing NUL byte is appended before the buffer is
%% handed to run_cmd_nif/1 (presumably the NIF treats it as a C string;
%% confirm against the NIF sources), and the NIF's reply is passed
%% through parse_result/1.
-spec run_cmd(Cmds, Format) -> result()
    when Cmds :: cmds() | iodata(),
         Format :: fmt().
run_cmd(Cmds, json) ->
    %% jiffy:encode/1 returns iodata(), which is not guaranteed to be a
    %% flat binary, so flatten it explicitly rather than splicing it
    %% into a binary construction (which would raise badarg on a list).
    Encoded = jiffy:encode(#{<< "nftables" >> => Cmds}),
    parse_result(run_cmd_nif(iolist_to_binary([Encoded, 16#00])));

run_cmd(Cmds, text) ->
    %% iolist_to_binary/1 accepts everything list_to_binary/1 does,
    %% plus a plain binary.
    parse_result(run_cmd_nif(iolist_to_binary([Cmds, 16#00]))).


%% ------------------------------------------------------------------
%% public API :: command templates
%% ------------------------------------------------------------------

%% @doc Build an `add table' command for table `TName' in the default
%% family, with no flags.
-spec nft_cmd_add_table(TName) -> nft_cmd()
    when TName :: string().
nft_cmd_add_table(TName) ->
    %% Delegate to the (TName, Flags) variant with an empty flag list.
    %% Passing ?DEFAULT_FAMILY as the first argument here would be taken
    %% as the table name by nft_cmd_add_table/2, with TName as its flags.
    nft_cmd_add_table(TName, []).

%% @doc Build an `add table' command for table `TName' with the given
%% flags, using the default family.
-spec nft_cmd_add_table(TName :: string(), Flags :: [binary()]) -> nft_cmd().
nft_cmd_add_table(Table, TableFlags) ->
    nft_cmd_add_table(?DEFAULT_FAMILY, Table, TableFlags).

%% @doc Build an `add table' command.
%% `Family' and `TName' are strings and get converted to binaries;
%% `Flags' is stored as given.
-spec nft_cmd_add_table(Family :: string(),
                        TName :: string(),
                        Flags :: [binary()]) -> nft_cmd().
nft_cmd_add_table(Family, TName, Flags) ->
    FamilyBin = list_to_binary(Family),
    NameBin = list_to_binary(TName),
    Table = #{family => FamilyBin, name => NameBin, flags => Flags},
    #{add => #{table => Table}}.


%% @doc Build a `delete table' command for table `TName' in the
%% default family.
-spec nft_cmd_del_table(TName :: string()) -> nft_cmd().
nft_cmd_del_table(Table) ->
    nft_cmd_del_table(?DEFAULT_FAMILY, Table).

%% @doc Build a `delete table' command.
%% Both arguments are strings and get converted to binaries.
-spec nft_cmd_del_table(Family :: string(), TName :: string()) -> nft_cmd().
nft_cmd_del_table(Family, TName) ->
    FamilyBin = list_to_binary(Family),
    NameBin = list_to_binary(TName),
    #{delete => #{table => #{family => FamilyBin, name => NameBin}}}.


%% @doc Build an `add rule' command for chain `CName' of table `TName'
%% in the default family.
-spec nft_cmd_add_rule(TName :: string(),
                       CName :: string(),
                       Expr :: [nft_expr()]) -> nft_cmd().
nft_cmd_add_rule(Table, Chain, Exprs) ->
    nft_cmd_add_rule(?DEFAULT_FAMILY, Table, Chain, Exprs).

%% @doc Build an `add rule' command.
%% `Family', `TName' and `CName' are strings and get converted to
%% binaries; `Expr' (a list of expression maps) is stored as given.
-spec nft_cmd_add_rule(Family :: string(),
                       TName :: string(),
                       CName :: string(),
                       Expr :: [nft_expr()]) -> nft_cmd().
nft_cmd_add_rule(Family, TName, CName, Expr) ->
    FamilyBin = list_to_binary(Family),
    TableBin = list_to_binary(TName),
    ChainBin = list_to_binary(CName),
    Rule = #{family => FamilyBin,
             table => TableBin,
             chain => ChainBin,
             expr => Expr},
    #{add => #{rule => Rule}}.


%% @doc Build a `delete rule' command for the rule with the given
%% handle in chain `CName' of table `TName', using the default family.
-spec nft_cmd_del_rule(TName :: string(),
                       CName :: string(),
                       Handle :: integer()) -> nft_cmd().
nft_cmd_del_rule(Table, Chain, RuleHandle) ->
    nft_cmd_del_rule(?DEFAULT_FAMILY, Table, Chain, RuleHandle).

%% @doc Build a `delete rule' command.
%% `Family', `TName' and `CName' are strings and get converted to
%% binaries; `Handle' is stored as given.
-spec nft_cmd_del_rule(Family :: string(),
                       TName :: string(),
                       CName :: string(),
                       Handle :: integer()) -> nft_cmd().
nft_cmd_del_rule(Family, TName, CName, Handle) ->
    FamilyBin = list_to_binary(Family),
    TableBin = list_to_binary(TName),
    ChainBin = list_to_binary(CName),
    Rule = #{family => FamilyBin,
             table => TableBin,
             chain => ChainBin,
             handle => Handle},
    #{delete => #{rule => Rule}}.


%% @doc Build a counter reference for counter `Name' in table `TName',
%% using the default family.
-spec nft_counter(TName :: string(), Name :: string()) -> nft_counter().
nft_counter(Table, CounterName) ->
    nft_counter(?DEFAULT_FAMILY, Table, CounterName).

%% @doc Build a counter reference map.
%% All three arguments are strings and get converted to binaries.
%% The result can be passed to nft_cmd_add_counter/1 and
%% nft_cmd_del_counter/1.
-spec nft_counter(Family :: string(),
                  TName :: string(),
                  Name :: string()) -> nft_counter().
nft_counter(Family, TName, Name) ->
    FamilyBin = list_to_binary(Family),
    TableBin = list_to_binary(TName),
    NameBin = list_to_binary(Name),
    #{family => FamilyBin, table => TableBin, name => NameBin}.


%% @doc Wrap a counter map into an `add counter' command.
-spec nft_cmd_add_counter(nft_counter()) -> nft_cmd().
nft_cmd_add_counter(Counter) ->
    Object = #{counter => Counter},
    #{add => Object}.


%% @doc Wrap a counter map into a `delete counter' command.
-spec nft_cmd_del_counter(nft_counter()) -> nft_cmd().
nft_cmd_del_counter(Counter) ->
    Object = #{counter => Counter},
    #{delete => Object}.


%% @doc Build a `list ruleset' command.
-spec nft_cmd_list_ruleset() -> nft_cmd().
nft_cmd_list_ruleset() ->
    Object = #{ruleset => null},
    #{list => Object}.


%% @doc Build a `list tables' command.
-spec nft_cmd_list_tables() -> nft_cmd().
nft_cmd_list_tables() ->
    Object = #{tables => null},
    #{list => Object}.


%% @doc Build a `list chain' command for chain `CName' of table `TName'
%% in the default family.
-spec nft_cmd_list_chain(TName :: string(), CName :: string()) -> nft_cmd().
nft_cmd_list_chain(Table, Chain) ->
    nft_cmd_list_chain(?DEFAULT_FAMILY, Table, Chain).


%% @doc Build a `list chain' command.
%% All three arguments are strings and get converted to binaries; the
%% chain name ends up under the `name' key.
-spec nft_cmd_list_chain(Family :: string(),
                         TName :: string(),
                         CName :: string()) -> nft_cmd().
nft_cmd_list_chain(Family, TName, CName) ->
    FamilyBin = list_to_binary(Family),
    TableBin = list_to_binary(TName),
    ChainBin = list_to_binary(CName),
    Chain = #{family => FamilyBin, table => TableBin, name => ChainBin},
    #{list => #{chain => Chain}}.


%% @doc Build a `list counters' command for table `TName' in the
%% default family.
-spec nft_cmd_list_counters(TName :: string()) -> nft_cmd().
nft_cmd_list_counters(Table) ->
    nft_cmd_list_counters(?DEFAULT_FAMILY, Table).


%% @doc Build a `list counters' command.
%% Both arguments are strings and get converted to binaries; they are
%% placed in a `#{family, name}' map nested under the `table' key.
%%
%% NOTE(review): confirm against libnftables-json(5) that a nested
%% table object (rather than flat `family'/`table' strings) is the
%% shape expected for `list counters'.
-spec nft_cmd_list_counters(Family :: string(), TName :: string()) -> nft_cmd().
nft_cmd_list_counters(Family, TName) ->
    FamilyBin = list_to_binary(Family),
    NameBin = list_to_binary(TName),
    Selector = #{table => #{family => FamilyBin, name => NameBin}},
    #{list => #{counters => Selector}}.


%% ------------------------------------------------------------------
%% public API :: expression templates
%% ------------------------------------------------------------------

%% @doc Build a `match' expression comparing a payload field to a value.
%% `Proto', `Field' and `Op' are strings and get converted to binaries;
%% `Value' is used as the right-hand side as given.
-spec nft_expr_match_payload({Proto :: string(), Field :: string()},
                             Value :: term(),
                             Op :: nft_op()) -> map().
nft_expr_match_payload({Proto, Field}, Value, Op) ->
    Payload = #{protocol => list_to_binary(Proto),
                field => list_to_binary(Field)},
    Match = #{left => #{payload => Payload},
              right => Value,
              op => list_to_binary(Op)},
    #{match => Match}.


%% @doc Build a `match' expression on the `protocol' field of the `ip'
%% header. `Proto' is a string and is converted to a binary.
-spec nft_expr_match_ip_proto(Proto :: string(), Op :: nft_op()) -> map().
nft_expr_match_ip_proto(Proto, Op) ->
    ProtoBin = list_to_binary(Proto),
    nft_expr_match_payload({"ip", "protocol"}, ProtoBin, Op).


%% @doc Build a `match' expression on the `saddr' field of the `ip'
%% header. `Addr' is a string and is converted to a binary.
-spec nft_expr_match_ip_saddr(Addr :: string(), Op :: nft_op()) -> map().
nft_expr_match_ip_saddr(Addr, Op) ->
    AddrBin = list_to_binary(Addr),
    nft_expr_match_payload({"ip", "saddr"}, AddrBin, Op).


%% @doc Build a `match' expression on the `daddr' field of the `ip'
%% header. `Addr' is a string and is converted to a binary.
-spec nft_expr_match_ip_daddr(Addr :: string(), Op :: nft_op()) -> map().
nft_expr_match_ip_daddr(Addr, Op) ->
    AddrBin = list_to_binary(Addr),
    nft_expr_match_payload({"ip", "daddr"}, AddrBin, Op).


%% @doc Build a `match' expression on the `dport' field of the `udp'
%% header. `DPort' is passed through as the right-hand side unchanged.
-spec nft_expr_match_udp_dport(DPort :: non_neg_integer(),
                               Op :: nft_op()) -> map().
nft_expr_match_udp_dport(DPort, Op) ->
    Field = {"udp", "dport"},
    nft_expr_match_payload(Field, DPort, Op).


%% @doc Build a `counter' expression referring to a counter by name.
%% `Name' is a string and is converted to a binary.
-spec nft_expr_counter(Name :: string()) -> map().
nft_expr_counter(Name) ->
    NameBin = list_to_binary(Name),
    #{counter => NameBin}.


%% @doc Build an `accept' verdict expression.
-spec nft_expr_accept() -> map().
nft_expr_accept() ->
    #{accept => null}.


%% ------------------------------------------------------------------
%% private API
%% ------------------------------------------------------------------

%% on_load hook: locate the NIF shared object and load it.
%%
%% The library is looked up under the application's priv directory.
%% If code:priv_dir/1 cannot resolve it, fall back to "../priv" when
%% that directory exists, and to "priv" otherwise.
init() ->
    PrivDir = case code:priv_dir(?LIBNAME) of
        {error, bad_name} ->
            Parent = filename:join(["..", priv]),
            case filelib:is_dir(Parent) of
                true -> Parent;
                _ -> priv
            end;
        Dir ->
            Dir
    end,
    erlang:load_nif(filename:join(PrivDir, ?LIBNAME), 0).


%% NIF entry point. This Erlang body is only a placeholder: it raises
%% `nif_library_not_loaded' when called without the native
%% implementation in place.
-spec run_cmd_nif(Cmd :: binary()) -> {ok, binary()} | {error, term()}.
run_cmd_nif(_Encoded) ->
    erlang:nif_error(nif_library_not_loaded).


%% Convert the raw NIF reply into a result().
%%
%%  * `{ok, <<>>}'   - no output was produced: `ok'.
%%  * `{ok, Json}'   - the output is decoded with jiffy; it must be an
%%                     object whose "nftables" member is a list starting
%%                     with a "metainfo" item. That first item is dropped
%%                     and the remaining items are returned as
%%                     `{ok, Items}'. Any other decoded shape (including
%%                     an empty list or a missing "metainfo" head) yields
%%                     `{error, {json_decode, unexpected}}'.
%%  * anything else  - returned unchanged (i.e. `{error, Reason}').
%%
%% Exceptions raised by jiffy:decode/2 itself are not caught here.
-spec parse_result(Res) -> result()
    when Res :: {ok, binary()} | {error, term()}.
parse_result({ok, << >>}) ->
    ok;

parse_result({ok, << Res/bytes >>}) ->
    case jiffy:decode(Res, [return_maps]) of
        %% Match the leading "metainfo" item right in the clause head, so
        %% that a reply without it takes the error branch below instead
        %% of crashing with a badmatch.
        #{<< "nftables" >> := [#{<< "metainfo" >> := _} | Items]} ->
            {ok, Items};
        _ ->
            {error, {json_decode, unexpected}}
    end;

parse_result(Error) ->
    Error.

%% vim:set ts=4 sw=4 et:
