Index: ip_fw2.c =================================================================== RCS file: /home/ncvs/src/sys/netinet/ip_fw2.c,v retrieving revision 1.84 diff -u -r1.84 ip_fw2.c --- ip_fw2.c 2 Nov 2004 22:22:21 -0000 1.84 +++ ip_fw2.c 3 Nov 2004 16:43:49 -0000 @@ -46,6 +46,7 @@ #if IPFW2 #include #include +#include #include #include #include @@ -119,18 +120,54 @@ struct ip_fw *rules; /* list of rules */ struct ip_fw *reap; /* list of rules to reap */ struct mtx mtx; /* lock guarding rule list */ + int busy_count; /* busy count for rw locks */ + int want_write; + struct cv cv; }; #define IPFW_LOCK_INIT(_chain) \ mtx_init(&(_chain)->mtx, "IPFW static rules", NULL, \ MTX_DEF | MTX_RECURSE) #define IPFW_LOCK_DESTROY(_chain) mtx_destroy(&(_chain)->mtx) -#define IPFW_LOCK(_chain) mtx_lock(&(_chain)->mtx) -#define IPFW_UNLOCK(_chain) mtx_unlock(&(_chain)->mtx) -#define IPFW_LOCK_ASSERT(_chain) do { \ +#define IPFW_WLOCK_ASSERT(_chain) do { \ mtx_assert(&(_chain)->mtx, MA_OWNED); \ NET_ASSERT_GIANT(); \ } while (0) +static __inline void +IPFW_RLOCK(struct ip_fw_chain *chain) +{ + mtx_lock(&chain->mtx); + chain->busy_count++; + mtx_unlock(&chain->mtx); +} + +static __inline void +IPFW_RUNLOCK(struct ip_fw_chain *chain) +{ + mtx_lock(&chain->mtx); + chain->busy_count--; + if (chain->busy_count == 0 && chain->want_write) + cv_signal(&chain->cv); + mtx_unlock(&chain->mtx); +} + +static __inline void +IPFW_WLOCK(struct ip_fw_chain *chain) +{ + mtx_lock(&chain->mtx); + chain->want_write++; + while (chain->busy_count > 0) + cv_wait(&chain->cv, &chain->mtx); +} + +static __inline void +IPFW_WUNLOCK(struct ip_fw_chain *chain) +{ + chain->want_write--; + cv_signal(&chain->cv); + mtx_unlock(&chain->mtx); +} + /* * list of rules for layer 3 */ @@ -1854,7 +1891,7 @@ args->f_id.dst_port = dst_port = ntohs(dst_port); after_ip_checks: - IPFW_LOCK(chain); /* XXX expensive? can we run lock free? */ + IPFW_RLOCK(chain); mtag = m_tag_find(m, PACKET_TAG_DIVERT, NULL); if (args->rule) { /* @@ -1866,7 +1903,7 @@ * the caller. */ if (fw_one_pass) { - IPFW_UNLOCK(chain); /* XXX optimize */ + IPFW_RUNLOCK(chain); return 0; } @@ -1883,13 +1920,13 @@ f = chain->rules; if (args->eh == NULL && skipto != 0) { if (skipto >= IPFW_DEFAULT_RULE) { - IPFW_UNLOCK(chain); + IPFW_RUNLOCK(chain); return(IP_FW_PORT_DENY_FLAG); /* invalid */ } while (f && f->rulenum <= skipto) f = f->next; if (f == NULL) { /* drop packet */ - IPFW_UNLOCK(chain); + IPFW_RUNLOCK(chain); return(IP_FW_PORT_DENY_FLAG); } } @@ -2438,7 +2475,7 @@ if (mtag == NULL) { /* XXX statistic */ /* drop packet */ - IPFW_UNLOCK(chain); + IPFW_RUNLOCK(chain); return IP_FW_PORT_DENY_FLAG; } dt = (struct divert_tag *)(mtag+1); @@ -2514,7 +2551,7 @@ } /* end of outer for, scan rules */ printf("ipfw: ouch!, skip past end of rules, denying packet\n"); - IPFW_UNLOCK(chain); + IPFW_RUNLOCK(chain); return(IP_FW_PORT_DENY_FLAG); done: @@ -2522,7 +2559,7 @@ f->pcnt++; f->bcnt += pktlen; f->timestamp = time_second; - IPFW_UNLOCK(chain); + IPFW_RUNLOCK(chain); return retval; pullup_failed: @@ -2540,7 +2577,7 @@ { struct ip_fw *rule; - IPFW_LOCK_ASSERT(chain); + IPFW_WLOCK_ASSERT(chain); for (rule = chain->rules; rule; rule = rule->next) rule->next_rule = NULL; @@ -2555,7 +2592,7 @@ { struct ip_fw *rule; - IPFW_LOCK(&layer3_chain); + IPFW_WLOCK(&layer3_chain); for (rule = layer3_chain.rules; rule; rule = rule->next) { ipfw_insn_pipe *cmd = (ipfw_insn_pipe *)ACTION_PTR(rule); @@ -2571,7 +2608,7 @@ !bcmp(&cmd->pipe_ptr, &match, sizeof(match)) ) bzero(&cmd->pipe_ptr, sizeof(cmd->pipe_ptr)); } - IPFW_UNLOCK(&layer3_chain); + IPFW_WUNLOCK(&layer3_chain); } /* @@ -2601,8 +2638,7 @@ rule->bcnt = 0; rule->timestamp = 0; - IPFW_LOCK(chain); - + IPFW_WLOCK(chain); if (chain->rules == NULL) { /* default rule */ chain->rules = rule; goto done; @@ -2649,7 +2685,7 @@ done: static_count++; static_len += l; - IPFW_UNLOCK(chain); + IPFW_WUNLOCK(chain); DEB(printf("ipfw: installed rule %d, static count now %d\n", rule->rulenum, static_count);) return (0); @@ -2669,7 +2705,7 @@ struct ip_fw *n; int l = RULESIZE(rule); - IPFW_LOCK_ASSERT(chain); + IPFW_WLOCK_ASSERT(chain); n = rule->next; IPFW_DYN_LOCK(); @@ -2715,7 +2751,7 @@ { struct ip_fw *prev, *rule; - IPFW_LOCK_ASSERT(chain); + IPFW_WLOCK_ASSERT(chain); flush_rule_ptrs(chain); /* more efficient to do outside the loop */ for (prev = NULL, rule = chain->rules; rule ; ) @@ -2763,7 +2799,7 @@ return EINVAL; } - IPFW_LOCK(chain); + IPFW_WLOCK(chain); rule = chain->rules; chain->reap = NULL; switch (cmd) { @@ -2774,7 +2810,7 @@ for (; rule->rulenum < rulenum; prev = rule, rule = rule->next) ; if (rule->rulenum != rulenum) { - IPFW_UNLOCK(chain); + IPFW_WUNLOCK(chain); return EINVAL; } @@ -2827,7 +2863,7 @@ */ rule = chain->reap; chain->reap = NULL; - IPFW_UNLOCK(chain); + IPFW_WUNLOCK(chain); if (rule) reap_rules(rule); return 0; @@ -2862,7 +2898,7 @@ struct ip_fw *rule; char *msg; - IPFW_LOCK(chain); + IPFW_WLOCK(chain); if (rulenum == 0) { norule_counter = 0; for (rule = chain->rules; rule; rule = rule->next) @@ -2885,13 +2921,13 @@ break; } if (!cleared) { /* we did not find any matching rules */ - IPFW_UNLOCK(chain); + IPFW_WUNLOCK(chain); return (EINVAL); } msg = log_only ? "ipfw: Entry %d logging count reset.\n" : "ipfw: Entry %d cleared.\n"; } - IPFW_UNLOCK(chain); + IPFW_WUNLOCK(chain); if (fw_verbose) log(LOG_SECURITY | LOG_NOTICE, msg, rulenum); @@ -3128,7 +3164,7 @@ int i; /* XXX this can take a long time and locking will block packet flow */ - IPFW_LOCK(chain); + IPFW_RLOCK(chain); for (rule = chain->rules; rule ; rule = rule->next) { /* * Verify the entry fits in the buffer in case the @@ -3144,7 +3180,7 @@ bp += i; } } - IPFW_UNLOCK(chain); + IPFW_RUNLOCK(chain); if (ipfw_dyn_v) { ipfw_dyn_rule *p, *last = NULL; @@ -3255,11 +3291,11 @@ * the old list without the need for a lock. */ - IPFW_LOCK(&layer3_chain); + IPFW_WLOCK(&layer3_chain); layer3_chain.reap = NULL; free_chain(&layer3_chain, 0 /* keep default rule */); rule = layer3_chain.reap, layer3_chain.reap = NULL; - IPFW_UNLOCK(&layer3_chain); + IPFW_WUNLOCK(&layer3_chain); if (layer3_chain.reap != NULL) reap_rules(rule); break; @@ -3463,6 +3499,9 @@ int error; layer3_chain.rules = NULL; + layer3_chain.want_write = 0; + layer3_chain.busy_count = 0; + cv_init(&layer3_chain.cv, "Condition variable for IPFW rw locks"); IPFW_LOCK_INIT(&layer3_chain); IPFW_DYN_LOCK_INIT(); callout_init(&ipfw_timeout, debug_mpsafenet ? CALLOUT_MPSAFE : 0); @@ -3536,11 +3575,11 @@ ip_fw_chk_ptr = NULL; ip_fw_ctl_ptr = NULL; callout_drain(&ipfw_timeout); - IPFW_LOCK(&layer3_chain); + IPFW_WLOCK(&layer3_chain); layer3_chain.reap = NULL; free_chain(&layer3_chain, 1 /* kill default rule */); reap = layer3_chain.reap, layer3_chain.reap = NULL; - IPFW_UNLOCK(&layer3_chain); + IPFW_WUNLOCK(&layer3_chain); if (reap != NULL) reap_rules(reap); flush_tables();