--- /usr/src/sys/netinet/ip_fw2.c.orig 2008-10-05 10:47:13.000000000 +0800 +++ /usr/src/sys/netinet/ip_fw2.c 2008-10-07 10:06:48.000000000 +0800 @@ -65,6 +65,7 @@ #include #include #include +#include #include #include #include @@ -143,6 +144,7 @@ MALLOC_DEFINE(M_IPFW, "IpFw/IpAcct", "IpFw/IpAcct chain's"); MALLOC_DEFINE(M_IPFW_TBL, "ipfw_tbl", "IpFw tables"); +MALLOC_DEFINE(M_IPFW_PTBL, "ipfw_ptbl", "IpFw port tables"); #define IPFW_NAT_LOADED (ipfw_nat_ptr != NULL) ipfw_nat_t *ipfw_nat_ptr = NULL; ipfw_nat_cfg_t *ipfw_nat_cfg_ptr; @@ -1972,6 +1974,170 @@ return (0); } +static int +add_port_table_entry(struct ip_fw_chain *ch, uint16_t tbl, uint16_t port, + uint16_t port1, uint32_t value) +{ + struct ipfw_port_table_entry *ent, *ent_tmp, *ent_temp; + uint16_t tval, tval1; + + if (tbl >= IPFW_TABLES_MAX) + return (EINVAL); + + ent = malloc(sizeof(*ent), M_IPFW_PTBL, M_NOWAIT | M_ZERO); + if (ent == NULL) + return (ENOMEM); + ent->value = value; + ent->port = port; + ent->port1 = port1; + IPFW_WLOCK(&V_layer3_chain); + if (LIST_EMPTY(&(ch->ptables[tbl]->port_head))) + goto port_add; + LIST_FOREACH_SAFE(ent_tmp, &(ch->ptables[tbl]->port_head), _next, ent_temp) { + tval = ntohs(ent_tmp->port); + tval1 = ntohs(ent_tmp->port1); + /* check whether port is in list, it shouldn't be in port range that is already in list */ + if (tval <= ntohs(port) && tval1 >= ntohs(port)) { + // printf("first case, ent_tmp->port: %d <= port: %d, ent_tmp->port1: %d >= port: %d\n", + // ntohs(ent_tmp->port), ntohs(port), ntohs(ent_tmp->port1), ntohs(port)); + goto port_exist; + } else if (tval <= ntohs(port1) && tval1 >= ntohs(port1)) { + // printf("first case, ent_tmp->port: %d <= port1: %d, ent_tmp->port1: %d >= port1: %d\n", + // ntohs(ent_tmp->port), ntohs(port1), ntohs(ent_tmp->port1), ntohs(port1)); + goto port_exist; + } else if (tval >= ntohs(port) && tval1 <= ntohs(port1)) { + // printf("first case, ent_tmp->port: %d >= port: %d, ent_tmp->port1: %d <= port1: %d\n", + // ntohs(ent_tmp->port), ntohs(port), ntohs(ent_tmp->port1), ntohs(port1)); + goto port_exist; + } + } + port_add: + LIST_INSERT_HEAD(&(ch->ptables[tbl]->port_head), ent, _next); + IPFW_WUNLOCK(&V_layer3_chain); + return (0); + port_exist: + free(ent, M_IPFW_PTBL); + IPFW_WUNLOCK(&V_layer3_chain); + return (EEXIST); +} + +static int +del_port_table_entry(struct ip_fw_chain *ch, uint16_t tbl, uint16_t port + ) +{ + struct ipfw_port_table_entry *ent, *ent_temp; + + if (tbl >= IPFW_TABLES_MAX) + return (EINVAL); + // printf("delete port: %d\n",port); + IPFW_WLOCK(ch); + LIST_FOREACH_SAFE(ent, &(ch->ptables[tbl]->port_head), _next, ent_temp) { + if (ent->port == port) { + LIST_REMOVE(ent, _next); + free(ent, M_IPFW_PTBL); + IPFW_WUNLOCK(ch); + return (0); + } + } + IPFW_WUNLOCK(ch); + return (0); +} + +static int +flush_port_table(struct ip_fw_chain *ch, uint16_t tbl) +{ + struct ipfw_port_table_entry *ent, *ent_temp; + + IPFW_WLOCK_ASSERT(ch); + + if (tbl >= IPFW_TABLES_MAX) + return (EINVAL); + + LIST_FOREACH_SAFE(ent, &(ch->ptables[tbl]->port_head), _next, ent_temp) { + LIST_REMOVE(ent, _next); + free(ent, M_IPFW_PTBL); + } + return (0); +} + +static void +flush_port_tables(struct ip_fw_chain *ch) +{ + uint16_t tbl; + + IPFW_WLOCK_ASSERT(ch); + + for (tbl = 0; tbl < IPFW_TABLES_MAX; tbl++) + flush_port_table(ch, tbl); +} + +static int +init_port_tables(struct ip_fw_chain *ch) +{ + int i; + struct ipfw_port_table *p = NULL; + + for (i = 0; i < IPFW_TABLES_MAX; i++) { + p = malloc(sizeof(*p), M_IPFW_PTBL, M_NOWAIT | M_ZERO); + ch->ptables[i] = p; + LIST_INIT(&(ch->ptables[i]->port_head)); + } + return (0); +} + +static int +lookup_port_table(struct ip_fw_chain *ch, uint16_t tbl, uint16_t port, + uint32_t *val) +{ + struct ipfw_port_table_entry *ent, *ent_temp; + + if (tbl >= IPFW_TABLES_MAX) + return (0); + + LIST_FOREACH_SAFE(ent, &(ch->ptables[tbl]->port_head), _next, ent_temp) { + if (ntohs(ent->port) <= port && ntohs(ent->port1) >= port) { + *val = ent->value; + return (1); + } + } + return (0); +} + +static int +count_port_table(struct ip_fw_chain *ch, uint32_t tbl, uint32_t *cnt) +{ + struct ipfw_port_table_entry *ent, *ent_temp; + + if (tbl >= IPFW_TABLES_MAX) + return (EINVAL); + + *cnt = 0; + LIST_FOREACH_SAFE(ent, &(ch->ptables[tbl]->port_head), _next, ent_temp) { + (*cnt)++; + } + return (0); +} + +static int +dump_port_table(struct ip_fw_chain *ch, struct ipfw_port_table *tbl) +{ + struct ipfw_port_table_entry *ent, *ent_temp; + + if (tbl->tbl >= IPFW_TABLES_MAX) + return (EINVAL); + tbl->cnt = 0; + //tbl = ch->ptables[tbl->tbl]; + //tbl->port_head = ch->ptables[tbl->tbl]->port_head; + LIST_FOREACH_SAFE(ent, &(ch->ptables[tbl->tbl]->port_head), _next, ent_temp) { + //printf("tbl->tbl: %d, port: %d\n",tbl->tbl, ntohs(ent->port)); + tbl->ent[tbl->cnt].port = ent->port; + tbl->ent[tbl->cnt].port1 = ent->port1; + tbl->ent[tbl->cnt].value = ent->value; + tbl->cnt++; + } + return (0); +} + static void fill_ugid_cache(struct inpcb *inp, struct ip_fw_ugid *ugp) { @@ -2777,6 +2943,26 @@ } break; + case O_IP_SRCPORT_LOOKUP: + case O_IP_DSTPORT_LOOKUP: + if ((proto==IPPROTO_UDP || proto==IPPROTO_TCP) + && offset == 0) { + uint16_t x = + (cmd->opcode == O_IP_DSTPORT_LOOKUP) ? + dst_port : src_port; + uint32_t v; + + match = lookup_port_table(chain, cmd->arg1, x, &v); + if (!match) + break; + if (cmdlen == F_INSN_SIZE(ipfw_insn_u16)) + match = + ((ipfw_insn_u32 *)cmd)->d[0] == v; + else + tablearg = v; + } + break; + case O_ICMPTYPE: match = (offset == 0 && proto==IPPROTO_ICMP && icmptype_match(ICMP(ulp), (ipfw_insn_u32 *)cmd) ); @@ -3922,6 +4108,18 @@ goto bad_size; break; + case O_IP_SRCPORT_LOOKUP: + case O_IP_DSTPORT_LOOKUP: + if (cmd->arg1 >= IPFW_TABLES_MAX) { + printf("ipfw: invalid port table number %d\n", + cmd->arg1); + return (EINVAL); + } + if (cmdlen != F_INSN_SIZE(ipfw_insn) && + cmdlen != F_INSN_SIZE(ipfw_insn_u16)) + goto bad_size; + break; + case O_MACADDR2: if (cmdlen != F_INSN_SIZE(ipfw_insn_mac)) goto bad_size; @@ -4384,6 +4582,92 @@ } break; + case IP_FW_PORT_TABLE_ADD: + { + struct ipfw_port_table_entry ent; + + error = sooptcopyin(sopt, &ent, + sizeof(ent), sizeof(ent)); + if (error) + break; + error = add_port_table_entry(&V_layer3_chain, ent.tbl, + ent.port, ent.port1, ent.value); + } + break; + + case IP_FW_PORT_TABLE_DEL: + { + struct ipfw_port_table_entry ent; + + error = sooptcopyin(sopt, &ent, + sizeof(ent), sizeof(ent)); + if (error) + break; + //printf("ent.port: %d\n", ent.port); + error = del_port_table_entry(&V_layer3_chain, ent.tbl, + ent.port); + } + break; + + case IP_FW_PORT_TABLE_FLUSH: + { + u_int16_t tbl; + + error = sooptcopyin(sopt, &tbl, + sizeof(tbl), sizeof(tbl)); + if (error) + break; + IPFW_WLOCK(&V_layer3_chain); + error = flush_port_table(&V_layer3_chain, tbl); + IPFW_WUNLOCK(&V_layer3_chain); + } + break; + + case IP_FW_PORT_TABLE_GETSIZE: + { + u_int32_t tbl, cnt; + + if ((error = sooptcopyin(sopt, &tbl, sizeof(tbl), + sizeof(tbl)))) + break; + IPFW_RLOCK(&V_layer3_chain); + error = count_port_table(&V_layer3_chain, tbl, &cnt); + IPFW_RUNLOCK(&V_layer3_chain); + if (error) + break; + error = sooptcopyout(sopt, &cnt, sizeof(cnt)); + } + break; + + case IP_FW_PORT_TABLE_LIST: + { + struct ipfw_port_table *tbl; + + if (sopt->sopt_valsize < sizeof(*tbl)) { + error = EINVAL; + break; + } + size = sopt->sopt_valsize; + tbl = malloc(size, M_TEMP, M_WAITOK); + error = sooptcopyin(sopt, tbl, size, sizeof(*tbl)); + if (error) { + free(tbl, M_TEMP); + break; + } + tbl->size = (size - sizeof(*tbl)) / + sizeof(struct ipfw_port_table_entry); + IPFW_RLOCK(&V_layer3_chain); + error = dump_port_table(&V_layer3_chain, tbl); + IPFW_RUNLOCK(&V_layer3_chain); + if (error) { + free(tbl, M_TEMP); + break; + } + error = sooptcopyout(sopt, tbl, size); + free(tbl, M_TEMP); + } + break; + case IP_FW_NAT_CFG: { if (IPFW_NAT_LOADED) @@ -4606,6 +4890,13 @@ uma_zdestroy(ipfw_dyn_rule_zone); return (error); } + error = init_port_tables(&V_layer3_chain); + if (error) { + IPFW_DYN_LOCK_DESTROY(); + IPFW_LOCK_DESTROY(&V_layer3_chain); + uma_zdestroy(ipfw_dyn_rule_zone); + return (error); + } ip_fw_ctl_ptr = ipfw_ctl; ip_fw_chk_ptr = ipfw_chk; callout_reset(&V_ipfw_timeout, hz, ipfw_tick, NULL); @@ -4623,6 +4914,7 @@ callout_drain(&V_ipfw_timeout); IPFW_WLOCK(&V_layer3_chain); flush_tables(&V_layer3_chain); + flush_port_tables(&V_layer3_chain); V_layer3_chain.reap = NULL; free_chain(&V_layer3_chain, 1 /* kill default rule */); reap = V_layer3_chain.reap, V_layer3_chain.reap = NULL;